import sys, time, os
import numpy as np
import argparse
import warnings
warnings.filterwarnings("ignore")
from tqdm.auto import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# Select the compute device. GPU index 2 is the preferred device; when fewer than
# three GPUs are visible (e.g. a single-GPU machine, or after restricting them via
# os.environ["CUDA_VISIBLE_DEVICES"]), fall back to GPU 0 instead of failing later
# with an invalid device ordinal. Without CUDA, run on the CPU.
if torch.cuda.is_available():
    _gpu_index = 2 if torch.cuda.device_count() > 2 else 0
    device = torch.device(f'cuda:{_gpu_index}')
else:
    device = torch.device('cpu')

import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AdamW
from sentence_transformers import SentenceTransformer
from datasets import Dataset as HFDataset

from load_data import *
from utils import *
from retriever import Retriever

# Command-line interface. Parsed once at import time into the module-level `args`,
# which the training/evaluation functions below read directly.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='cora',
                    choices=['cora', 'pubmed', 'ogbn-arxiv', 'ogbn-products'])
parser.add_argument('--max_length', type=int, default=512,
                    help='Maximum token length of the LLM input and target sequences')
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--lr_llm', type=float, default=1e-4, help='Learning rate of the LLM optimizer')
parser.add_argument('--lr_retriever', type=float, default=1e-5, help='Learning rate of the retriever')
parser.add_argument('--input_mode', type=str, default='t_rag_tape',
                    choices=['t_rag_x', 'c_rag_x', 'tc_rag_x', 'ct_rag_x',
                             't_rag_tape', 'c_rag_tape', 'tc_rag_tape', 'ct_rag_tape',
                             't_ragt_tape', 'c_ragt_tape', 'tc_ragt_tape', 'ct_ragt_tape'],
                    help='Input prompt template (prefixed with the label-pool size before use)')
parser.add_argument('--input_label_pool', type=int, default=2,
                    choices=[0, 1, 2, 3], help='Number of labels to choose from the prediction of teacher GNNs')
parser.add_argument('--output_mode', type=str, default='g_x',
                    choices=['g_x', 'd_x', 'p_x', 'g_r', 'p_r', 'g_2l'],
                    help='Target output template of the LLM')
parser.add_argument('--retrieved_neighbors', type=int, default=4, help='Number of neighbors from semantic retrieval')
parser.add_argument('--train_neighbors', type=int, default=1, help='Number of augmented samples per training sample used for training the LLM')
parser.add_argument('--eval_neighbors', type=int, default=1, help='Number of augmented samples per validation/test sample used for evaluating the LLM')
parser.add_argument('--num_neighbors', type=int, default=5, help='Number of neighbors from PPR')
parser.add_argument('--model_id', type=str, default='google/flan-t5-small',
                    choices=['google/flan-t5-small',
                             'google/flan-t5-base',
                             'google/flan-t5-large'])
parser.add_argument('--retriever_id', type=str, default='sentence-transformers/all-MiniLM-L6-v2',
                    choices=['sentence-transformers/sentence-t5-base', 'sentence-transformers/all-MiniLM-L6-v2'])
# Both boolean switches default to True; the positive flag exists only for symmetry,
# and the matching --no-* flag is what actually turns the feature off.
parser.add_argument('--verbose', action='store_true', default=True, help='Show progress bars (on by default)')
parser.add_argument('--no-verbose', action='store_false', dest='verbose', help='Disable verbose output')
parser.add_argument('--train_retriever', action='store_true', default=True, help='Fine-tune the retriever jointly with the LLM (on by default)')
parser.add_argument('--no-train_retriever', action='store_false', dest='train_retriever', help='Freeze the retriever (do not train it)')
parser.add_argument('--eval_per_steps', type=int, default=50, help='Run validation every N training steps')
parser.add_argument('--patience', type=int, default=3, help='Early stopping patience')

args = parser.parse_args()

def train(model, optimizer, batch):
    """Run one optimisation step of the LLM on a mini-batch of (input, target) strings.

    Args:
        model: the seq2seq LLM being fine-tuned; switched to training mode here.
        optimizer: optimizer over ``model``'s parameters; stepped once.
        batch: pair ``(inputs, outputs)`` of prompt strings and target strings.

    Returns:
        ``(loss, nll_per_sample)``: the scalar training loss as a Python float, and a
        1-D tensor (no gradient) holding the mean token-level NLL of each sample's
        target, which the caller uses as the training signal for the retriever.
    """
    model.train()
    optimizer.zero_grad()
    input_texts, target_texts = batch
    encodings = tokenizer(input_texts, padding='max_length', truncation=True,
                          max_length=args.max_length, return_tensors='pt').to(device)
    labels = tokenizer(target_texts, padding='max_length', truncation=True,
                       max_length=args.max_length, return_tensors='pt').to(device)
    # Padding positions are set to -100 so the T5 loss ignores them.
    labels["input_ids"][labels["input_ids"] == tokenizer.pad_token_id] = -100
    encodings["labels"] = labels["input_ids"]

    # One forward pass provides both the optimised loss and the per-sample scores,
    # so both come from the same dropout sample. A second model call would double
    # compute and activation memory.
    model_output = model(**encodings)
    loss = model_output.loss

    # The per-sample NLL only supervises the retriever, so keep it out of the LLM graph.
    with torch.no_grad():
        log_probs = F.log_softmax(model_output.logits, dim=-1)
        target_ids = encodings["labels"]
        nll_per_token = F.nll_loss(log_probs.view(-1, log_probs.size(-1)), target_ids.view(-1),
                                   ignore_index=-100, reduction='none')
        nll_per_token = nll_per_token.view(target_ids.size(0), -1)
        token_mask = (target_ids != -100).float()
        # Average over the non-padding target tokens of each sample.
        nll_per_sample = (nll_per_token * token_mask).sum(dim=1) / token_mask.sum(dim=1)

    loss.backward()
    optimizer.step()

    return loss.detach().item(), nll_per_sample.detach()

def inference(model, batch):
    """Score a mini-batch of (input, target) strings with the LLM, without training it.

    The model is put in eval mode and run under ``torch.no_grad``. Returns a detached
    1-D tensor with, for each sample, the mean negative log-likelihood of its target
    tokens (padding excluded). The caller uses these scores to supervise the retriever.
    """
    model.eval()
    with torch.no_grad():
        source_texts, target_texts = batch

        def _encode(texts):
            return tokenizer(texts, padding='max_length', truncation=True,
                             max_length=args.max_length, return_tensors='pt').to(device)

        model_inputs = _encode(source_texts)
        target_ids = _encode(target_texts)["input_ids"]
        # Padding positions become -100 so they are ignored (T5 label convention).
        target_ids[target_ids == tokenizer.pad_token_id] = -100
        model_inputs["labels"] = target_ids

        logits = model(**model_inputs).logits
        vocab_size = logits.size(-1)
        token_nll = F.nll_loss(F.log_softmax(logits, dim=-1).view(-1, vocab_size),
                               target_ids.view(-1), ignore_index=-100, reduction='none')
        token_nll = token_nll.view(target_ids.size(0), -1)
        valid_tokens = (target_ids != -100).float()
        per_sample_nll = (token_nll * valid_tokens).sum(dim=1) / valid_tokens.sum(dim=1)

        return per_sample_nll.detach()

def eval_loss(retriever, model, input_dataloader):
    """Return the mean LLM loss over ``input_dataloader`` using retrieved contexts.

    For each batch of node indices, the top ``args.eval_neighbors`` prototypes are
    retrieved without gradients and rendered into prompt/target pairs. These pairs
    are re-batched with ``batch_size`` and scored. The result is the unweighted mean
    of the per-mini-batch losses (``np.mean``).
    """
    model.eval()
    iterator = tqdm(input_dataloader) if args.verbose else input_dataloader
    with torch.no_grad():
        batch_losses = []
        for batch_idx in iterator:
            queries = [passage_list[node] for node in batch_idx]
            _, top_k_indices = retriever.retrieve_no_grad(queries, topk=args.eval_neighbors)
            # Every retrieved prototype is kept; nothing is masked out.
            keep = torch.ones_like(top_k_indices).bool().to(device)
            context_tensor = string_tensor(top_k_indices, prototype_list, keep)  # batch x topk

            prompts, targets = generate_rag_input_output_list(
                lists=lists,
                functions=(prompt_function, input_function, output_function),
                indices=batch_idx, context_tensor=context_tensor)

            # The generated pair list can be larger than the node batch (presumably one
            # pair per retrieved context), so score it in mini-batches.
            pair_loader = torch.utils.data.DataLoader(Seq2SeqDataset(prompts, targets),
                                                      batch_size=batch_size)
            for prompt_batch, target_batch in pair_loader:
                encoded = tokenizer(prompt_batch, padding='max_length', truncation=True,
                                    max_length=args.max_length, return_tensors='pt').to(device)
                target_ids = tokenizer(target_batch, padding='max_length', truncation=True,
                                       max_length=args.max_length, return_tensors='pt').to(device)["input_ids"]
                # Padding positions become -100 so the T5 loss ignores them.
                target_ids[target_ids == tokenizer.pad_token_id] = -100
                encoded["labels"] = target_ids
                batch_losses.append(model(**encoded).loss.item())
        return np.mean(batch_losses)

def generate(retriever, model, input_dataloader):
    """Greedily decode LLM outputs for every node in ``input_dataloader``.

    For each batch of node indices, the top ``args.eval_neighbors`` prototypes are
    retrieved without gradients and rendered into prompts. The prompts are re-batched
    with ``batch_size`` and decoded with ``do_sample=False``. Returns the decoded
    strings (special tokens skipped) in dataloader order.

    NOTE(review): no ``max_new_tokens``/``max_length`` is passed to ``model.generate``,
    so the output length cap comes from the model's generation config. Confirm it is
    long enough for the selected ``--output_mode``.
    """
    model.eval()
    iterator = tqdm(input_dataloader) if args.verbose else input_dataloader
    with torch.no_grad():
        decoded_outputs = []
        for batch_idx in iterator:
            queries = [passage_list[node] for node in batch_idx]
            _, top_k_indices = retriever.retrieve_no_grad(queries, topk=args.eval_neighbors)
            # Every retrieved prototype is kept; nothing is masked out.
            keep = torch.ones_like(top_k_indices).bool().to(device)
            context_tensor = string_tensor(top_k_indices, prototype_list, keep)  # batch x topk

            prompts, targets = generate_rag_input_output_list(
                lists=lists,
                functions=(prompt_function, input_function, output_function),
                indices=batch_idx, context_tensor=context_tensor)

            # The generated prompt list can be larger than the node batch, so decode in mini-batches.
            pair_loader = torch.utils.data.DataLoader(Seq2SeqDataset(prompts, targets),
                                                      batch_size=batch_size)
            for prompt_batch, _unused_targets in pair_loader:
                encoded = tokenizer(prompt_batch, padding='max_length', truncation=True,
                                    max_length=args.max_length, return_tensors='pt').to(device)
                generated_ids = model.generate(**encoded, do_sample=False)
                decoded_outputs.extend(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
    return decoded_outputs

def evaluate_acc(groundtruth, output, label_set):
    """Score generated texts against ground-truth labels.

    All strings are lower-cased. For each generated text, ``matching_order`` orders
    the labels by their appearance in the text, and the first one is taken as the
    prediction. Texts that mention no label get the prediction ``'None'`` and are
    counted as nonsense.

    Returns:
        ``(acc, nonsense)``: ``acc`` is ``accuracy(predictions, groundtruth)``, and
        ``nonsense`` is the fraction of outputs in which no label was found.
    """
    truth = [label.lower() for label in groundtruth]
    lowered_outputs = [text.lower() for text in output]
    candidates = [label.lower() for label in label_set]
    # In '2l' output modes the label pool is repeated first and the true label comes
    # after it, so matching is done with rev=True.
    reverse_order = '2l' in args.output_mode

    predictions = []
    unmatched = 0
    for text in lowered_outputs:
        matches = matching_order(candidates, text, rev=reverse_order)
        if len(matches) == 0:
            matches = ['None']
            unmatched += 1
        predictions.append(matches[0])

    acc = accuracy(predictions, truth)
    return acc, unmatched / len(lowered_outputs)

# Unpack the command-line configuration into module-level names used below.
model_id = args.model_id
epochs, batch_size, lr_llm, lr_retriever = args.epochs, args.batch_size, args.lr_llm, args.lr_retriever

dataset_folder = "processed_data"
split_folder = f"raw_data/{args.dataset}/splits"
dataset = args.dataset
input_mode, output_mode = args.input_mode, args.output_mode
retrieved_neighbors, train_neighbors = args.retrieved_neighbors, args.train_neighbors
input_label_pool = args.input_label_pool
num_neighbors = args.num_neighbors

# At least one retrieved context must train the LLM, and the LLM cannot be trained on
# more contexts than were retrieved. Use parser.error rather than `assert`, which is
# stripped when Python runs with -O.
if not 0 < train_neighbors <= retrieved_neighbors:
    parser.error(f"--train_neighbors must be between 1 and --retrieved_neighbors "
                 f"({retrieved_neighbors}), got {train_neighbors}")

print(f"LLM: {model_id}, Dataset: {dataset}")
print(f"Input mode: {input_mode}, Output mode: {output_mode}, Input label pool: {input_label_pool}, Num neighbors: {num_neighbors}")
print(f"Epochs: {epochs}, Batch size: {batch_size}, lr LLM: {lr_llm}, lr Retriever: {lr_retriever}")
print(f"Retrieved neighbors: {retrieved_neighbors}, Training neighbors: {args.train_neighbors}, Eval neighbors: {args.eval_neighbors}")
print(f"Retriever: {args.retriever_id}, Train retriever: {args.train_retriever}, Eval per steps: {args.eval_per_steps}, Patience: {args.patience}")

# Tag the input mode with the teacher-GNN label-pool size, e.g. '2l_t_rag_tape'
# (a pool of 0 yields '0l_...').
input_mode = f"{input_label_pool}l_{input_mode}"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# The product graph uses its own prompt templates; the citation graphs share one set.
if dataset != 'ogbn-products':
    from templates.citation_templates import get_template
else:
    from templates.products_templates import get_template

# 'c' selects the classification-task templates.
prompt_function, input_function, output_function = get_template('c', input_mode, output_mode)
(title_list, content_list, label_list,
 neighbors_list, rationale_list, gpt_list) = load_meta_data_lists(dataset_folder, dataset, input_mode, output_mode)
label_set = set(label_list)
# Keep only the first `num_neighbors` neighbors of every node.
neighbors_list = [node_neighbors[:num_neighbors] for node_neighbors in neighbors_list]
# The GPT responses double as the retriever's query passages.
passage_list = gpt_list

# Cora/PubMed are evaluated on the splits for seeds 0-4. The OGB datasets use one
# given split (no seed), which is simply evaluated five times.
seeds = list(range(5)) if dataset in ('cora', 'pubmed') else [0] * 5

def _build_rag_pairs(batch_idx, context_indices, context_mask):
    """Render prompt/target strings for the nodes in ``batch_idx`` from the given retrieved prototypes."""
    context_tensor = string_tensor(context_indices, prototype_list, context_mask)  # batch x n_contexts
    return generate_rag_input_output_list(
        lists=lists,
        functions=(prompt_function, input_function, output_function),
        indices=batch_idx, context_tensor=context_tensor)


accs = []
for seed in seeds:
    print(f"\nSeed: {seed}")
    # A fresh LLM, optimizer and retriever are built for every seed.
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)
    optimizer = AdamW(model.parameters(), lr=lr_llm)
    total_parameters = sum(p.numel() for p in model.parameters())
    print(f"Total number of parameters of LLM: {total_parameters}")

    retriever = Retriever(dataset_folder=dataset_folder,
                          corpus_name=dataset,
                          encoder_name=args.retriever_id,
                          lr=args.lr_retriever,
                          seed=seed,
                          device=device)

    if dataset == 'cora' or dataset == 'pubmed':
        train_idx, valid_idx, test_idx = load_split_idx(split_folder, dataset, seed)
    elif dataset == 'ogbn-arxiv':
        train_idx, valid_idx, test_idx = load_split_idx_ogb(split_folder, dataset)
    elif dataset == 'ogbn-products':
        train_idx, valid_idx, test_idx = load_split_idx_tape(split_folder, dataset)
    print("Generating input and output lists...")

    label_and_prob_list, raw_label_and_prob_list = load_label_and_prob_list(dataset_folder, dataset, seed)
    prototype_list = load_prototype_list(dataset_folder, dataset, seed)
    lists = (label_set, title_list, content_list, label_list, label_and_prob_list, neighbors_list, rationale_list, gpt_list, raw_label_and_prob_list)

    train_dataloader = torch.utils.data.DataLoader(train_idx, shuffle=True, batch_size=batch_size)
    valid_dataloader = torch.utils.data.DataLoader(valid_idx, batch_size=batch_size)
    test_dataloader = torch.utils.data.DataLoader(test_idx, batch_size=batch_size)

    model_save_dir = f'finetuned_models/{model_id}_{dataset}_{seed}_{input_mode}_{output_mode}'
    retriever_save_dir = f'finetuned_models/{args.retriever_id}_{dataset}_{seed}_{input_mode}_{output_mode}'
    early_stopper = EarlyStopping(patience=args.patience)

    total_steps_per_epoch = len(train_dataloader)
    total_steps = total_steps_per_epoch * epochs
    current_step = 0  # global step counter across epochs
    early_stop_flag = False
    show_template = True
    for epoch in range(epochs):
        pbar = tqdm(train_dataloader) if args.verbose else train_dataloader

        for batch_idx in pbar:
            batch_query = [passage_list[i] for i in batch_idx]

            if args.train_retriever:
                retriever.model.train()
                retriever.optimizer.zero_grad()
            else:
                retriever.model.eval()

            # Only build an autograd graph through the query encoder when the retriever is trained.
            with torch.set_grad_enabled(args.train_retriever):
                batch_query_encoding = retriever.model.tokenizer(batch_query, padding=True, truncation=True, return_tensors='pt').to(device)
                batch_query_emb = retriever.model(batch_query_encoding)['sentence_embedding']
                top_k_scores, top_k_indices = retriever.retrieve(batch_query_emb, topk=args.retrieved_neighbors)

            # All retrieved entries are kept: the pooled prototype texts are not expected to
            # coincide with the target node's own text, so no self-masking is applied.
            top_k_indices_mask = torch.ones_like(top_k_indices).bool().to(device)

            # The first `train_neighbors` contexts train the LLM; the remaining ones are only
            # scored by it. All of them together supervise the retriever.
            training_LLM_indices = top_k_indices[:, :args.train_neighbors]
            inference_LLM_indices = top_k_indices[:, args.train_neighbors:]
            training_LLM_mask = top_k_indices_mask[:, :args.train_neighbors]
            inference_LLM_mask = top_k_indices_mask[:, args.train_neighbors:]

            input_list, output_list = _build_rag_pairs(batch_idx, training_LLM_indices, training_LLM_mask)
            # The pair list can be larger than the node batch, so re-batch it.
            training_LLM_dataloader = torch.utils.data.DataLoader(Seq2SeqDataset(input_list, output_list),
                                                                  batch_size=batch_size)
            if show_template:
                print(f"\n\nTemplate Example: \n{input_list[0]}\n{output_list[0]}\n")
                show_template = False

            # Per-sample NLLs collected here later serve as the retriever's training signal.
            total_loss = 0
            llm_scores_from_training = torch.tensor([], device=device)
            for batch in training_LLM_dataloader:
                batch_loss, batch_nll_per_sample = train(model, optimizer, batch)
                total_loss += batch_loss
                llm_scores_from_training = torch.cat((llm_scores_from_training, batch_nll_per_sample))

            if args.train_retriever:
                n_rows = batch_idx.shape[0]
                llm_scores = llm_scores_from_training.view(n_rows, -1)
                # When train_neighbors == retrieved_neighbors there are no inference-only
                # contexts. Viewing an empty tensor as (n_rows, -1) would raise, so skip it.
                if inference_LLM_indices.shape[1] > 0:
                    input_list, output_list = _build_rag_pairs(batch_idx, inference_LLM_indices, inference_LLM_mask)
                    inference_LLM_dataloader = torch.utils.data.DataLoader(Seq2SeqDataset(input_list, output_list),
                                                                           batch_size=batch_size)
                    llm_scores_from_inference = torch.tensor([], device=device)
                    for batch in inference_LLM_dataloader:
                        batch_nll_per_sample = inference(model, batch)
                        llm_scores_from_inference = torch.cat((llm_scores_from_inference, batch_nll_per_sample))
                    llm_scores = torch.cat((llm_scores, llm_scores_from_inference.view(n_rows, -1)), dim=1)

                # Target distribution over the retrieved contexts: lower LLM NLL -> higher weight.
                score_groundtruth = F.softmax(-llm_scores, dim=-1)  # batch x retrieved neighbors
                score_retriever = F.log_softmax(top_k_scores, dim=-1)
                retriever_loss = F.kl_div(score_retriever, score_groundtruth, reduction='batchmean')
                retriever_loss.backward()
                retriever.optimizer.step()
                retriever_loss = retriever_loss.detach().item()
            else:
                retriever_loss = 'N/A'

            llm_train_loss = total_loss / len(training_LLM_dataloader)
            if args.verbose:
                pbar.set_postfix({'llm_train_loss': llm_train_loss, 'retriever_train_loss': retriever_loss})
            elif current_step % 10 == 0:
                # retriever_loss is the string 'N/A' when the retriever is frozen and
                # cannot be formatted with ':.3f'.
                retriever_loss_text = retriever_loss if isinstance(retriever_loss, str) else f"{retriever_loss:.3f}"
                print(f"Step {current_step}, llm train loss: {llm_train_loss:.3f}, retriever train loss: {retriever_loss_text}")

            if current_step > 0 and current_step % args.eval_per_steps == 0:
                valid_loss = eval_loss(retriever, model, valid_dataloader)
                epoch_progress = (current_step + 1) / total_steps_per_epoch
                if retriever_loss == 'N/A':
                    print(f"""Epoch {epoch_progress:.3f}, llm train loss: {llm_train_loss:.3f}, llm valid loss: {valid_loss:.3f}""")
                else:
                    print(f"""Epoch {epoch_progress:.3f}, llm train loss: {llm_train_loss:.3f}, llm valid loss: {valid_loss:.3f}, retriever train loss: {retriever_loss:.3f}""")
                early_stop_flag, early_stop_verbose = early_stopper.step(valid_loss, [tokenizer, model, retriever], epoch_progress, [model_save_dir, retriever_save_dir])
                print(early_stop_verbose)
                if early_stop_flag:
                    break
            current_step += 1

            torch.cuda.empty_cache()
        if early_stop_flag:
            break

    # On early stop, reload the checkpoints from the save directories handed to
    # early_stopper.step (presumably the best-validation ones).
    # NOTE(review): when training finishes without early stopping, the last model is
    # evaluated rather than the best checkpoint. Confirm this is intended.
    if early_stop_flag:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_save_dir).to(device)
        retriever.model = SentenceTransformer(retriever_save_dir).to(device)

    generated_text = generate(retriever, model, test_dataloader)
    # Every label is repeated args.eval_neighbors times to align with the generated outputs.
    groundtruth = [label_list[i] for i in test_idx for _ in range(args.eval_neighbors)]
    test_acc, test_nonsense = evaluate_acc(groundtruth, generated_text, label_set)
    print(f"Test acc: {test_acc:.4f}, nonsense: {test_nonsense:.4f}")
    accs.append(test_acc)
print(f"Average test acc: {np.mean(accs):.4f}, std: {np.std(accs):.4f}")